import os
import logging
import torch

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision
from train_argument import parser, print_args

import random
import copy
import numpy as np

from time import time
from model import net
from utils import *
from Simulator import Simulator
from Split_Data import Non_iid_split_fmnist, Non_iid_split_cifar, data_stats


def main(args):
    save_folder = args.affix
    
    log_folder = os.path.join(args.log_root, save_folder) #return a new path 
    model_folder = os.path.join(args.model_root, save_folder)

    makedirs(log_folder)
    makedirs(model_folder)


    setattr(args, 'log_folder', log_folder) #setattr(obj, var, val) assign object attribute to its value, just like args.'log_folder' = log_folder
    setattr(args, 'model_folder', model_folder)

    logger = create_logger(log_folder, 'train', 'info')
    print_args(args, logger) #It prints arguments


    sequence_length = 28 
    input_size = 28
    num_layers = 2
    num_classes = 10


    if args.dataset =='fmnist':
        tr_dataset = torchvision.datasets.FashionMNIST(args.data_root, 
                                        train=True, 
                                        transform=torchvision.transforms.ToTensor(), 
                                        download=True)

        # evaluation during training
        te_dataset = torchvision.datasets.FashionMNIST(args.data_root, 
                                        train=False, 
                                        transform=torchvision.transforms.ToTensor(), 
                                        download=True)    
        num_classes = 10
        
        Non_iid_tr_datasets, Non_iid_te_datasets = Non_iid_split_fmnist(
                num_classes, args.num_clients, tr_dataset, te_dataset, args.alpha)
        local_tr_data_loaders = [DataLoader(dataset, num_workers = 0,
                                            batch_size = args.batch_size, 
                                            shuffle = True)
                        for dataset in Non_iid_tr_datasets]
        local_te_data_loaders = [DataLoader(dataset, num_workers = 0,
                                            batch_size = args.batch_size, 
                                            shuffle = True)
                        for dataset in Non_iid_te_datasets]
        
        client_data_counts, client_total_samples = data_stats(Non_iid_tr_datasets, 10, args.num_clients)
        client_te_data_counts, client_total_te_samples = data_stats(Non_iid_te_datasets, 10, args.num_clients)
        
        while np.min(client_total_te_samples) < args.batch_size and np.min(client_total_samples) < args.batch_size: #if a batch has only one sample, then we have an error in BN layers
            print('reloading data.....', np.min(client_total_samples, client_total_te_samples))
            Non_iid_tr_datasets, Non_iid_te_datasets = Non_iid_split_fmnist(
            10, args.num_clients, tr_dataset, te_dataset, args.alpha)
            client_data_counts, client_total_samples = data_stats(Non_iid_tr_datasets, 10, args.num_clients)
            client_te_data_counts, client_total_te_samples = data_stats(Non_iid_te_datasets, 10, args.num_clients)

        

    local_tr_data_loaders = [DataLoader(dataset, num_workers = 0,
                                        batch_size = args.batch_size, 
                                        shuffle = True, drop_last=True)
                    for dataset in Non_iid_tr_datasets]
    local_te_data_loaders = [DataLoader(dataset, num_workers = 0,
                                        batch_size = args.batch_size, 
                                        shuffle = True, drop_last=True)
                    for dataset in Non_iid_te_datasets]


    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print("currrent device: ", device)

    if args.mask == 1:
        model = net.LSTM_model(input_size, args.hidden_size, args.num_layers, num_classes, mask=args.mask).to(device) 
        logger.info(model)

        trainer = Simulator(args, logger, local_tr_data_loaders, local_te_data_loaders, device)
        trainer.initialization(copy.deepcopy(model))
        trainer.FL_loop()

if __name__ == '__main__':
    args = parser()

    # Restrict the visible GPUs before seeding and before main() queries CUDA.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # Seed every RNG source (Python, NumPy, torch CPU, all CUDA devices), in
    # that order, so runs are reproducible.
    for seed_rng in (random.seed, np.random.seed,
                     torch.manual_seed, torch.cuda.manual_seed_all):
        seed_rng(args.seed)
    torch.backends.cudnn.deterministic = True

    main(args)